Count univalue subtrees

Time: O(N); Space: O(H), where N is the number of nodes and H is the height of the tree; difficulty: medium

Given a binary tree, count the number of uni-value subtrees.

A uni-value subtree is a subtree in which every node has the same value.

Example 1:

Input: root = [5,1,5,5,5,null,5]

    5
   / \
  1   5
 / \   \
5   5   5

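Output: 4

Explanation: the four uni-value subtrees are the three leaves (each 5) and the subtree rooted at the root's right child (5 -> 5).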

[1]:
class TreeNode(object):
    """A binary tree node: a value plus optional left and right children."""

    def __init__(self, x):
        # Store the payload; both child links start empty and are wired up
        # by the caller after construction.
        self.val = x
        self.left = self.right = None
[3]:
class Solution1(object):
    """
    Count uni-value subtrees with a post-order traversal.

    Time: O(N)
    Space: O(H)

    The traversal uses explicit stacks instead of recursion, so degenerate
    (chain-shaped) trees deeper than Python's recursion limit do not raise
    RecursionError.
    """
    def countUnivalSubtrees(self, root):
        """
        Return the number of subtrees of ``root`` whose nodes all share one value.

        :type root: TreeNode
        :rtype: int
        """
        _, count = self.isUnivalSubtrees(root, 0)
        return count

    def isUnivalSubtrees(self, root, count):
        """
        Walk the tree rooted at ``root`` in post-order, counting uni-value subtrees.

        :param root: subtree root, or None for an empty tree
        :param count: running total that this subtree's uni-value subtrees
            are added onto
        :return: ``[is_uni, count]``, where ``is_uni`` tells whether the whole
            subtree at ``root`` is uni-value (True for an empty tree) and
            ``count`` is the input ``count`` plus the uni-value subtrees found
        """
        if not root:
            return [True, count]

        # Entries are (node, children_done). A node is first expanded (its
        # children are pushed), then revisited once both subtrees are done.
        pending = [(root, False)]
        # is_uni flags of finished subtrees whose parent is not finished yet.
        # The left subtree finishes before the right, so a parent pops the
        # right flag first, then the left one.
        finished = []
        while pending:
            node, children_done = pending.pop()
            if not children_done:
                pending.append((node, True))
                # Push right before left so the left subtree is handled first.
                if node.right:
                    pending.append((node.right, False))
                if node.left:
                    pending.append((node.left, False))
                continue

            right = finished.pop() if node.right else True
            left = finished.pop() if node.left else True
            is_uni = (self.isSame(node, node.left, left)
                      and self.isSame(node, node.right, right))
            if is_uni:
                count += 1
            finished.append(is_uni)

        return [finished.pop(), count]

    def isSame(self, root, child, is_uni):
        """
        Return True if ``child`` does not stop ``root``'s subtree from being uni-value.

        A missing child never does. A present child must itself be the root
        of a uni-value subtree (``is_uni``) and hold the same value as ``root``.
        """
        return not child or (is_uni and root.val == child.val)
[4]:
s = Solution1()

# Example 1, [5,1,5,5,5,null,5]:
#       5
#      / \
#     1   5
#    / \   \
#   5   5   5
root = TreeNode(5)
root.left = TreeNode(1)
root.left.left = TreeNode(5)
root.left.right = TreeNode(5)
root.right = TreeNode(5)
root.right.right = TreeNode(5)

assert s.countUnivalSubtrees(root) == 4